👨‍💻 Getting Started with Few Shot Learning

Importing Python Libraries 📕 📗 📘 📙

In [1]:
#!pip install -r libraries_to_install.txt
In [2]:
import os
import random
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from pathlib import Path
from statistics import mean
import torchvision.models as models
from easyfsl.methods.utils import evaluate
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from easyfsl.methods import PrototypicalNetworks, FewShotClassifier
from easyfsl.modules import resnet12
from easyfsl.datasets import PLANT
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader
import matplotlib.image as mpimg
from PIL import Image
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")
In [3]:
if torch.cuda.is_available()==True:
    print('GPUs are available! ')
else:
    print('Please configure GPSs are not available')
GPUs are available! 

Exploratory data analysis 🔎 📊

  • Exploration of data is not neccessory for training the model but its a good practice to look at the dataset so that we can analyse that what type of data we are using and how we can handle it.

Sample images¶

In [4]:
img = mpimg.imread('./data/PLANT/100/DSC05982.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
Out[4]:
<matplotlib.image.AxesImage at 0x1e3f3be6670>
In [5]:
img = mpimg.imread('./data/PLANT/256/DSC09062.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
Out[5]:
<matplotlib.image.AxesImage at 0x1e3f3c65340>
In [6]:
img = mpimg.imread('./data/PLANT/316/DSC03840.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
Out[6]:
<matplotlib.image.AxesImage at 0x1e3f3cdc520>
In [7]:
img = mpimg.imread('./data/PLANT/330/DSC06136.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
Out[7]:
<matplotlib.image.AxesImage at 0x1e3f3d5b4c0>
In [8]:
img = mpimg.imread('./data/PLANT/348/DSC01037.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
Out[8]:
<matplotlib.image.AxesImage at 0x1e3f5dce6d0>
In [9]:
img = mpimg.imread('./data/PLANT/370/DSC01163.jpg')
print(img.shape)
plt.imshow(img)
(256, 256, 3)
Out[9]:
<matplotlib.image.AxesImage at 0x1e3f5e41bb0>

Some random Imeges of class 110¶

In [10]:
images_data = []
for i in os.listdir('./data/PLANT/110/')[0:10]:
    split = i.split('_')
    images_data.append(Image.open('./data/PLANT/110/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
    plt.subplot(5,2,i+1)
    plt.imshow(images_data[i])
plt.show()

Some random Imeges of class 150¶

In [11]:
images_data = []
for i in os.listdir('./data/PLANT/150/')[0:10]:
    split = i.split('_')
    images_data.append(Image.open('./data/PLANT/150/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
    plt.subplot(5,2,i+1)
    plt.imshow(images_data[i])
plt.show()

Some random Imeges of class 200¶

In [12]:
images_data = []
for i in os.listdir('./data/PLANT/200/')[0:10]:
    split = i.split('_')
    images_data.append(Image.open('./data/PLANT/200/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
    plt.subplot(5,2,i+1)
    plt.imshow(images_data[i])
plt.show()

Some random Imeges of class 370¶

In [13]:
images_data = []
for i in os.listdir('./data/PLANT/370/')[0:10]:
    split = i.split('_')
    images_data.append(Image.open('./data/PLANT/370/' + i))
plt.figure(figsize=(10,10))
for i in range(10):
    plt.subplot(5,2,i+1)
    plt.imshow(images_data[i])
plt.show()

Training Data

We use training data when we train the models. We feed train data to machine learning and deep learning models so that model can learn from the data.

Validation Data

We use validation data while training the model. We use this data to evalaute the performance that how the model perform on training time.

Testing Data

We use testing data after training the model. We use this data to evalaute the performance that how the model perform after training. So in this way first we get predictions from the trained model without giving the labels and then we compare the true labels with predictions and get the performance of th model..

Setting of classes and per class samples.¶

In [15]:
n_way = 271 # number of classes
n_shot = 1 #number of samples 
n_query = 1 # Number of images per class in the query set

DEVICE = "cuda"
n_workers = 5

Data Preparation of Training and Validation¶

In [16]:
n_tasks_per_epoch = 2000
n_validation_tasks = 2000

train_set = PLANT(split="train", training=True)
val_set = PLANT(split="test", training=False)

train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

Training data loader¶

In [17]:
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)

Validation data loader¶

In [18]:
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=val_sampler.episodic_collate_fn,
)

👨Using the Mnasnet as transfer learning

In [19]:
tranfer_learning_model = models.mnasnet1_0(pretrained=True)

Using the tranfser learning in Few Shot Leanring¶

In [20]:
few_shot_classifier = PrototypicalNetworks(tranfer_learning_model).to(DEVICE)

Hyper parameter tuning for the Few Shot Learning model¶

In [21]:
LOSS_FUNCTION = nn.CrossEntropyLoss()
n_epochs = 50
scheduler_milestones = [20, 30]
scheduler_gamma = 0.1
learning_rate = 1e-2

train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

Setting the function of epoch to monitor the real time logs¶

In [22]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    all_loss = []
    model.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()
            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            classification_scores = model(query_images.to(DEVICE))

            loss = LOSS_FUNCTION(classification_scores, query_labels.to(DEVICE))
            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

Start training the Few Shot Learning model¶

In [23]:
best_state = few_shot_classifier.state_dict()
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = few_shot_classifier.state_dict()
    train_scheduler.step()
Epoch 0
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.14it/s, loss=9.32]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:02<00:00,  5.82it/s, accuracy=0.243]
Epoch 1
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:02<00:00,  5.04it/s, loss=9.35]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  3.19it/s, accuracy=0.291]
Epoch 2
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:03<00:00,  4.07it/s, loss=9.16]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.84it/s, accuracy=0.292]
Epoch 3
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  3.09it/s, loss=9.46]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.32it/s, accuracy=0.297]
Epoch 4
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.40it/s, loss=9.7]
Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.38it/s, accuracy=0.306]
Epoch 5
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.97it/s, loss=8.52]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.16it/s, accuracy=0.316]
Epoch 6
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.47it/s, loss=8.18]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.40it/s, accuracy=0.321]
Epoch 7
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.74it/s, loss=8.47]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.25it/s, accuracy=0.323]
Epoch 8
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.46it/s, loss=8.22]
Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.15it/s, accuracy=326]
Epoch 9
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.93it/s, loss=8.62]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.83it/s, accuracy=0.326]
Epoch 10
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.67it/s, loss=8.94]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.44it/s, accuracy=0.331]
Epoch 11
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.71it/s, loss=7.2]
Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.46it/s, accuracy=0.332]
Epoch 12
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.60it/s, loss=7.14]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.62it/s, accuracy=0.339]
Epoch 13
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.65it/s, loss=7.21]
Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.44it/s, accuracy=0.349]
Epoch 14
Training: 100%|████████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.96it/s, loss=7.2]
Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.87it/s, accuracy=0.349]
Epoch 15
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.55it/s, loss=7.39]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.96it/s, accuracy=0.35]
Epoch 16
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.54it/s, loss=6.74]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.99it/s, accuracy=0.351]
Epoch 17
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.85it/s, loss=6.41]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.45it/s, accuracy=0.353]
Epoch 18
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.70it/s, loss=6.72]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.69it/s, accuracy=0.354]
Epoch 19
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.48it/s, loss=6.8]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:18<00:00,  1.34s/it, accuracy=0.354]
Epoch 20
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.46it/s, loss=6.23]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.74it/s, accuracy=0.36]
Epoch 21
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.55it/s, loss=6.19]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.41it/s, accuracy=0.368]
Epoch 22
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.61it/s, loss=5.2]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.56it/s, accuracy=0.369]
Epoch 23
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.52it/s, loss=5.97]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.62it/s, accuracy=0.371]
Epoch 24
Training: 100%|████████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.56it/s, loss=5.77]
Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.94it/s, accuracy=0.376]
Epoch 25
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.42it/s, loss=5.36]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.59it/s, accuracy=0.378]
Epoch 26
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.54it/s, loss=5.37]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:10<00:00,  1.31it/s, accuracy=0.38]
Epoch 27
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.61it/s, loss=4.26]
Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.44it/s, accuracy=0.384]
Epoch 28
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.56it/s, loss=4.39]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:16<00:00,  1.20s/it, accuracy=0.405]
Epoch 29
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.52it/s, loss=4.77]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:10<00:00,  1.34it/s, accuracy=0.408]
Epoch 30
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.39it/s, loss=4.76]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.46it/s, accuracy=0.421]
Epoch 31
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.31it/s, loss=3.11]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.74it/s, accuracy=0.427]
Epoch 32
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.48it/s, loss=3.94]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:10<00:00,  1.34it/s, accuracy=0.429]
Epoch 33
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.80it/s, loss=3.32]
Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.65it/s, accuracy=0.432]
Epoch 34
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.45it/s, loss=3.56]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.41it/s, accuracy=0.443]
Epoch 35
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.53it/s, loss=3.46]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:09<00:00,  1.42it/s, accuracy=0.443]
Epoch 36
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.68it/s, loss=2.22]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.73it/s, accuracy=0.445]
Epoch 37
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.94it/s, loss=2.92]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.77it/s, accuracy=0.452]
Epoch 38
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.56it/s, loss=1.34]
Validation: 100%|████████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.97it/s, accuracy=0.455]
Epoch 39
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.59it/s, loss=1.28]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.83it/s, accuracy=0.477]
Epoch 40
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.39it/s, loss=0.93]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.85it/s, accuracy=0.478]
Epoch 41
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:04<00:00,  2.82it/s, loss=0.84]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.78it/s, accuracy=0.479]
Epoch 42
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.65it/s, loss=0.81]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:19<00:00,  1.41s/it, accuracy=0.5]
Epoch 43
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.43it/s, loss=0.79]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.65it/s, accuracy=0.506]
Epoch 44
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.53it/s, loss=0.78]
Validation: 100%|███████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.70it/s, accuracy=0.51]
Epoch 45
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.54it/s, loss=1.75]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.94it/s, accuracy=0.515]
Epoch 46
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.56it/s, loss=1.72]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.03it/s, accuracy=0.519]
Epoch 47
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.48it/s, loss=1.54]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.65it/s, accuracy=0.522]
Epoch 48
Training: 100%|██████████████████████████████████████████████████████████████| 2000/2000 [00:05<00:00,  2.51it/s, loss=1.45]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:26<00:00,  1.86s/it, accuracy=0.523]
Epoch 49
Training: 100%|█████████████████████████████████████████████████████████████| 2000/2000 [00:06<00:00,  2.14it/s, loss=1.34]
Validation: 100%|██████████████████████████████████████████████████████| 2000/2000 [00:07<00:00,  1.95it/s, accuracy=0.529]

Using the best model state from the above experimnents¶

In [24]:
few_shot_classifier.load_state_dict(best_state)
Out[24]:
<All keys matched successfully>

Evaluation of trained model using test data¶

Loading the test data using data loader¶

In [25]:
n_test_tasks = 2000

test_set = PLANT(split="test", training=False)
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

Calculating the accuracy¶

In [26]:
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print(f"Average accuracy : {(100 * accuracy):.1f} %")
100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:08<00:00,  1.59it/s, accuracy=0.524]
Average accuracy : 52.4 %